package com.menghao.rpc.consumer.handle.tcp;

import com.menghao.rpc.consumer.balance.LoadBalancer;
import com.menghao.rpc.consumer.balance.RandomLoadBalancer;
import com.menghao.rpc.consumer.handle.ReferenceAgent;
import com.menghao.rpc.consumer.model.ReferenceKey;
import com.menghao.rpc.consumer.model.RpcRequest;
import com.menghao.rpc.exception.InvokeException;
import com.menghao.rpc.netty.TcpConnectionContainer;
import com.menghao.rpc.netty.model.TcpConnection;
import com.menghao.rpc.spring.BeansManager;
import lombok.Getter;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

/**
 * <p>TCP-based implementation of {@link ReferenceAgent}.</p>
 * <p>Every method called on the original service interface is routed through {@link #invoke},
 * which sends the request over a TCP connection managed by {@link TcpConnectionContainer} (Netty).</p>
 * <p>{@code sourceInterface} together with {@code implCode} uniquely identifies a service.</p>
 * <p>The provider list is guarded by a read/write lock: connection selection takes the read lock,
 * provider refreshes take the write lock.</p>
 *
 * @author MarvelCode
 */
public class TcpReferenceAgent implements ReferenceAgent {

    /** Service interface proxied by this agent. */
    @Getter
    private final Class sourceInterface;
    /** Implementation code, taken from {@code ReferenceKey#getName()}. */
    @Getter
    private final String implCode;

    /**
     * Current snapshot of provider addresses in {@code host:port} form. It is never null and is
     * replaced (not mutated) under the write lock.
     */
    private List<String> providerHosts;

    private final TcpConnectionContainer tcpConnectionContainer;

    private final LoadBalancer defaultBalancer = new RandomLoadBalancer();

    /** Guards {@link #providerHosts} and keeps the connection container in step with it. */
    private final ReadWriteLock lock = new ReentrantReadWriteLock();

    /**
     * Creates an agent for the service described by {@code referenceKey}. The agent starts with no
     * providers; they are supplied later through {@link #setProviderHosts(List)}.
     *
     * @param referenceKey supplies the service interface and implementation code
     */
    public TcpReferenceAgent(ReferenceKey referenceKey) {
        this.sourceInterface = referenceKey.getSourceInterface();
        this.implCode = referenceKey.getName();
        this.tcpConnectionContainer = BeansManager.getInstance().getBeanByType(TcpConnectionContainer.class);
        this.providerHosts = new ArrayList<>();
    }

    /**
     * Proxies one interface method call. It builds the request, picks a provider connection via
     * the load balancer, executes the call and blocks until the result is available.
     *
     * @param method the interface method being invoked
     * @param args   the call arguments
     * @return the remote result
     * @throws InvokeException if no service provider is currently available
     */
    @Override
    public Object invoke(Method method, Object[] args) {
        // Build the request parameters
        RpcRequest rpcRequest = makeParam(method, args);
        // Pick a TCP connection through load balancing
        TcpConnection tcpConnection = select();
        // Build the invocation context
        InvocationContext invocationContext = new InvocationContext(tcpConnection, rpcRequest);
        // Execute the call
        invocationContext.execute();
        // Block until the result is available
        return invocationContext.get();
    }

    /**
     * Replaces the provider list and reconciles TCP connections with it. Connections to providers
     * that disappeared are removed, and newly listed providers get a connection registered.
     *
     * @param hosts provider addresses in {@code host:port} form. The list is copied defensively,
     *              and {@code null} is treated as "no providers".
     */
    @Override
    public void setProviderHosts(List<String> hosts) {
        // Take a private snapshot so later mutation by the caller cannot race with select()
        List<String> snapshot = new ArrayList<>();
        if (hosts != null) {
            snapshot.addAll(hosts);
        }
        lock.writeLock().lock();
        try {
            // NOTE(review): connections are registered/removed while the write lock is held, so
            // concurrent select() calls wait. Presumably this is intentional, so that a host is
            // never published before its connection exists. Confirm the container calls are fast.
            refreshConnection(providerHosts, snapshot);
            this.providerHosts = snapshot;
        } finally {
            lock.writeLock().unlock();
        }
    }

    /**
     * Chooses a provider with the load balancer and returns the container's connection for it.
     *
     * @throws InvokeException if no provider is currently available
     */
    private TcpConnection select() {
        lock.readLock().lock();
        try {
            if (providerHosts == null || providerHosts.isEmpty()) {
                throw new InvokeException("There are currently no service providers available");
            }
            String address = defaultBalancer.select(providerHosts);
            String[] hostAndPort = splitAddress(address);
            return tcpConnectionContainer.get(hostAndPort[0], Integer.valueOf(hostAndPort[1]));
        } finally {
            lock.readLock().unlock();
        }
    }

    /**
     * Applies the difference between the previous and current provider lists to the container.
     * Providers that went offline are removed first, then newly added providers are registered.
     *
     * @param lastHost providers from the previous refresh
     * @param nowHost  providers from this refresh
     */
    private void refreshConnection(List<String> lastHost, List<String> nowHost) {
        Set<String> previous = new HashSet<>(lastHost);
        Set<String> current = new HashSet<>(nowHost);

        // Previously alive but not alive now: close and remove their connections
        Set<String> lostHost = new HashSet<>(previous);
        lostHost.removeAll(current);
        // Alive now but not previously: register new connections
        Set<String> addHost = new HashSet<>(current);
        addHost.removeAll(previous);

        for (String host : lostHost) {
            String[] hostAndPort = splitAddress(host);
            tcpConnectionContainer.remove(hostAndPort[0], Integer.valueOf(hostAndPort[1]));
        }
        for (String host : addHost) {
            String[] hostAndPort = splitAddress(host);
            tcpConnectionContainer.register(hostAndPort[0], Integer.valueOf(hostAndPort[1]));
        }
    }

    /**
     * Splits a {@code host:port} address at its last colon, so hosts that themselves contain
     * colons keep their full text.
     *
     * @param address provider address
     * @return a two-element array {@code {host, port}} (the port is not yet parsed)
     * @throws IllegalArgumentException if the address is null, has no colon, or has an empty
     *                                  host or port part
     */
    private static String[] splitAddress(String address) {
        int colon = address == null ? -1 : address.lastIndexOf(':');
        if (colon <= 0 || colon == address.length() - 1) {
            throw new IllegalArgumentException("Provider address must be in host:port form: " + address);
        }
        return new String[]{address.substring(0, colon), address.substring(colon + 1)};
    }

}
